from __future__ import division, print_function, absolute_import

import os
import pickle

# NOTE: the relative order of the imports below is kept as in the original
# because the star imports may shadow names depending on their position.
import numpy as np
from CDTools.tools.cmath import *
import torch as t
from RIP_tools import *

#
# This script runs a numerical simulation of the total exit-wave intensity
# produced by each individual object pixel, to report in the supplement

probe_maxk = 128

# Set GPU to false if you don't have a GPU
GPU = True

field_size = int(np.ceil(probe_maxk)) * 2
# This sets the focal spot size of our BLR probe
probe_fov_radius = (field_size * 4)//10

# This defines that we do our error comparison within the central
# half of the image.
error_check_padding = 1/ 4


bs_factors = np.linspace(0,5/6,6)


resolution_ratios = [0.4,0.6,0.8]
full_pixel_intensity_stack = []
for resolution_ratio in resolution_ratios:
    print('resolution ratio', resolution_ratio)
    # This defines a reconstruction with R = 0.5
    resolution = int(probe_maxk * resolution_ratio) * 2
    pixel_intensity_stack = []
    for bs_factor in bs_factors:
        print('beamstop factor', bs_factor)
        repeat_intensities = []
        nrepeats = int(200 * 0.8 / resolution_ratio**2)
        for repeat in range(nrepeats):
            print('repeat',repeat,'of',nrepeats, end='\r')
            # This generates an ideal BLR probe with the given parameters
            nearfield = generate_blr_probe([field_size,field_size],
                                           probe_fov_radius, 
                                           probe_maxk,
                                           beamstop_factor=bs_factor)
            pixel_intensities = np.zeros((resolution, resolution))
            original_image = np.zeros((resolution, resolution), dtype=np.complex64)
            original_image = complex_to_torch(original_image).to(dtype=t.float32)
            # All this does is upsample the probe in real space to avoid
            # aliasing in the probe-object interaction
            simulation_nearfield = expand_probe(nearfield, original_image)

            if GPU==True:
                simulation_nearfield = simulation_nearfield.to(device='cuda:0')
                original_image = original_image.to(device='cuda:0')
                
            for pixel in range(resolution**2):
                # We generate an image with a single pixel lit up
                if pixel >=1 :
                    original_image.view(resolution**2,2)[pixel-1,0] = 0
                    
                original_image.view(resolution**2,2)[pixel,0] = 1
        
                # Now we perform the simulation
                exit_wave = interact(original_image, simulation_nearfield)
                intensities = cabssq(exit_wave)
                pixel_intensities.ravel()[pixel] = np.float32(t.sum(intensities).cpu().numpy())
                #print(pixel_intensities.ravel()[pixel])

            repeat_intensities.append(pixel_intensities)
        pixel_intensity_stack.append(repeat_intensities)
    full_pixel_intensity_stack.append(pixel_intensity_stack)


results = {'resolution ratios': resolution_ratios, 'beamstop_factors': bs_factors, 'intensities': full_pixel_intensity_stack}
with open('data/uniformity trial.pickle', 'wb') as f:
    pickle.dump(results, f)
